#include "ColliderDefinitions.cginc"
#include "ContactHandling.cginc"
#include "DistanceFunctions.cginc"
#include "Simplex.cginc"
#include "Bounds.cginc"
#include "SolverParameters.cginc"
#include "Optimization.cginc"

#pragma kernel GenerateContacts

// Per-particle data. positions, orientations and principalRadii are handed to Optimize();
// principalRadii.x and velocities are additionally interpolated over the simplex by this kernel.
StructuredBuffer<float4> positions;
StructuredBuffer<quaternion> orientations;
StructuredBuffer<float4> principalRadii;
StructuredBuffer<float4> velocities;

// particle indices of all simplices; a simplex occupies simplexSize consecutive entries starting at simplexStart.
StructuredBuffer<int> simplices;
StructuredBuffer<aabb> simplexBounds; // bounding box of each simplex, indexed by simplex index.

// per-collider transform and shape, both indexed by collider index.
StructuredBuffer<transform> transforms;
StructuredBuffer<shape> shapes;

// heightfield data:
// headers are indexed by shape.dataIndex; header.firstSample is the offset of the field's first sample in heightFieldSamples.
StructuredBuffer<HeightFieldHeader> heightFieldHeaders;
StructuredBuffer<float> heightFieldSamples;

// candidate pairs: x = simplex index, y = collider index.
StructuredBuffer<uint2> contactPairs;
// index of the first pair of each shape type inside contactPairs.
StructuredBuffer<int> contactOffsetsPerType;

// output contacts, appended using the buffer's hidden counter.
RWStructuredBuffer<contact> contacts;
// entry [11 + 4 * shapeType] is read as the pair count for that shape type;
// entries [0] and [3] are raised (InterlockedMax) as contacts are appended.
RWStructuredBuffer<uint> dispatchBuffer;

// only element 0 is used: it is multiplied with the collider transform to get a collider-to-solver transform.
StructuredBuffer<transform> worldToSolver;

uint maxContacts; // contacts whose counter value is >= maxContacts are not written.
float deltaTime;  // scales the normal velocity in the speculative contact test.

// Appends a contact between a simplex and the heightfield triangle that was just optimized against,
// provided the simplex is (or will be within this step) close enough to the surface.
//
// colliderPoint: closest surface point/normal returned by Optimize().
// convexPoint:   closest point on the simplex returned by Optimize().
// simplexBary:   barycentric coordinates of the simplex point (stored as pointA of the contact).
// simplexStart/simplexSize: range of the simplex inside the simplices buffer.
// simplexIndex/colliderIndex: stored as bodyA/bodyB of the contact.
// s: collider shape, provides contactOffset and the inversion sign of the normal.
void AppendHeightfieldContact(SurfacePoint colliderPoint, float4 convexPoint, float4 simplexBary,
                              int simplexStart, int simplexSize,
                              int simplexIndex, int colliderIndex, shape s)
{
    // interpolate radius and velocity at the simplex point using its barycentric coordinates:
    float4 velocity = FLOAT4_ZERO;
    float simplexRadius = 0;
    for (int j = 0; j < simplexSize; ++j)
    {
        int particleIndex = simplices[simplexStart + j];
        simplexRadius += principalRadii[particleIndex].x * simplexBary[j];
        velocity += velocities[particleIndex] * simplexBary[j];
    }

    // signed distance along the normal, and normal velocity:
    float dAB = dot(convexPoint - colliderPoint.pos, colliderPoint.normal);
    float vel = dot(velocity, colliderPoint.normal);

    // speculative test: keep the contact if the predicted distance at the end of the step is within range.
    if (vel * deltaTime + dAB <= simplexRadius + s.contactOffset + collisionMargin)
    {
        // the counter is always incremented; contacts beyond capacity are simply not written.
        uint count = contacts.IncrementCounter();
        if (count < maxContacts)
        {
            contact c = (contact)0;

            c.pointB = colliderPoint.pos;
            c.normal = colliderPoint.normal * s.isInverted();
            c.pointA = simplexBary;
            c.bodyA = simplexIndex;
            c.bodyB = colliderIndex;

            contacts[count] = c;

            // keep thread group count (entry 0) and contact count (entry 3) up to date:
            InterlockedMax(dispatchBuffer[0],(count + 1) / 128 + 1);
            InterlockedMax(dispatchBuffer[3], count + 1);
        }
    }
}

// Generates contacts between simplices and heightfield colliders.
// One thread per (simplex, heightfield collider) pair: the simplex bounds are transformed to
// collider space, and every heightfield cell they overlap is split into two triangles that
// are tested against the simplex.
[numthreads(128, 1, 1)]
void GenerateContacts (uint3 id : SV_DispatchThreadID) 
{
    uint i = id.x;

    // entry #11 in the dispatch buffer is the amount of pairs for the first shape type,
    // with the entries of successive shape types 4 elements apart.
    if (i >= dispatchBuffer[11 + 4 * HEIGHTMAP_SHAPE]) return; 

    int firstPair = contactOffsetsPerType[HEIGHTMAP_SHAPE];
    int simplexIndex = contactPairs[firstPair + i].x;
    int colliderIndex = contactPairs[firstPair + i].y;
    shape s = shapes[colliderIndex];

    // collider has no heightfield data assigned:
    if (s.dataIndex < 0) return;

    // the heightfield resolution is read from the shape's center field:
    int resolutionU = (int)s.center.x;
    int resolutionV = (int)s.center.y;

    // a field needs at least 2x2 samples to contain a cell. This also avoids dividing by zero below.
    if (resolutionU < 2 || resolutionV < 2) return;

    HeightFieldHeader header = heightFieldHeaders[s.dataIndex];
    
    Heightfield fieldShape;
    fieldShape.colliderToSolver = worldToSolver[0].Multiply(transforms[colliderIndex]);
    fieldShape.s = s;
    
    // invert a full matrix here to accurately represent collider bounds scale.
    float4x4 solverToCollider = Inverse(TRS(fieldShape.colliderToSolver.translation.xyz, fieldShape.colliderToSolver.rotation, fieldShape.colliderToSolver.scale.xyz));
    aabb simplexBound = simplexBounds[simplexIndex].Transformed(solverToCollider);

    int simplexSize;
    int simplexStart = GetSimplexStartAndSize(simplexIndex, simplexSize);

    // calculate terrain cell size:
    float cellWidth = s.size.x / (resolutionU - 1);
    float cellHeight = s.size.z / (resolutionV - 1);

    // calculate particle bounds min/max cells:
    int2 cellMin = int2((int)floor(simplexBound.min_[0] / cellWidth), (int)floor(simplexBound.min_[2] / cellHeight));
    int2 cellMax = int2((int)floor(simplexBound.max_[0] / cellWidth), (int)floor(simplexBound.max_[2] / cellHeight));

    // clamp the range to the cells that actually exist (0 .. resolution - 2). Besides skipping
    // out-of-field cells without iterating over them, this bounds the loops when the unclamped
    // range is degenerate (zero cell size or non-finite bounds), which could otherwise keep
    // the thread spinning over a huge or never-ending range.
    int firstU = max(cellMin[0], 0);
    int firstV = max(cellMin[1], 0);
    int lastU = min(cellMax[0], resolutionU - 2);
    int lastV = min(cellMax[1], resolutionV - 2);

    for (int su = firstU; su <= lastU; ++su)
    {
        for (int sv = firstV; sv <= lastV; ++sv)
        {
            // neighbor sample indices. Always inside the field, since su/sv never exceed resolution - 2.
            int csu1 = su + 1;
            int csv1 = sv + 1;

            // sample heights:
            float h1 = heightFieldSamples[header.firstSample + sv * resolutionU + su] * s.size.y;
            float h2 = heightFieldSamples[header.firstSample + sv * resolutionU + csu1] * s.size.y;
            float h3 = heightFieldSamples[header.firstSample + csv1 * resolutionU + su] * s.size.y;
            float h4 = heightFieldSamples[header.firstSample + csv1 * resolutionU + csu1] * s.size.y;

            // a negative height at the cell's first corner makes the kernel skip the whole cell
            // (NOTE(review): presumably this marks a hole in the terrain - confirm).
            if (h1 < 0) continue;
            h1 = abs(h1);
            h2 = abs(h2);
            h3 = abs(h3);
            h4 = abs(h4);

            // cell extents in collider space:
            float min_x = su * s.size.x / (resolutionU - 1);
            float max_x = csu1 * s.size.x / (resolutionU - 1);
            float min_z = sv * s.size.z / (resolutionV - 1);
            float max_z = csv1 * s.size.z / (resolutionV - 1);

            float4 convexPoint;

            // barycentric coordinates start at the simplex barycenter for each cell, and are
            // shared by both triangle tests of the cell.
            float4 simplexBary = BarycenterForSimplexOfSize(simplexSize);

            // ------contact against the first triangle------:
            float4 v1 = float4(min_x, h3, max_z, 0);
            float4 v2 = float4(max_x, h4, max_z, 0);
            float4 v3 = float4(min_x, h1, min_z, 0);

            fieldShape.tri.Cache(v1, v2, v3);
            fieldShape.triNormal.xyz = normalizesafe(cross((v2 - v1).xyz, (v3 - v1).xyz));

            SurfacePoint colliderPoint = Optimize(fieldShape, positions, orientations, principalRadii,
                                                  simplices, simplexStart, simplexSize, simplexBary, convexPoint, surfaceCollisionIterations, surfaceCollisionTolerance);

            AppendHeightfieldContact(colliderPoint, convexPoint, simplexBary,
                                     simplexStart, simplexSize, simplexIndex, colliderIndex, s);

            // ------contact against the second triangle------:
            v1 = float4(min_x, h1, min_z, 0);
            v2 = float4(max_x, h4, max_z, 0);
            v3 = float4(max_x, h2, min_z, 0);

            fieldShape.tri.Cache(v1, v2, v3);
            fieldShape.triNormal.xyz = normalizesafe(cross((v2 - v1).xyz, (v3 - v1).xyz));

            colliderPoint = Optimize(fieldShape, positions, orientations, principalRadii,
                                     simplices, simplexStart, simplexSize, simplexBary, convexPoint, surfaceCollisionIterations, surfaceCollisionTolerance);

            AppendHeightfieldContact(colliderPoint, convexPoint, simplexBary,
                                     simplexStart, simplexSize, simplexIndex, colliderIndex, s);
        }
    }
}